import numpy as np
from torch import optim
import mine
import torch


def _kl_conjugate(t):
    """Return ``exp(t - 1)``, the second-term function of the 'KL' bound."""
    return torch.exp(t - 1)


def _pearson_conjugate(t):
    """Return ``0.25 * t ** 2 + t``, the second-term function of the 'PEARSON' bound."""
    return 0.25 * t ** 2 + t


# f_measure -> (function applied to T2 outputs in term2, function applied to T3 outputs in term3)
_F_MEASURE_TERMS = {
    'KL': (_kl_conjugate, _kl_conjugate),
    'PEARSON_KL': (_pearson_conjugate, _kl_conjugate),
    'PEARSON': (_pearson_conjugate, _pearson_conjugate),
}


def _decode_rewards(dataset, r_decoder):
    """Decode every sample of ``dataset`` and decide whether the decoded rewards need flipping.

    Returns ``(rewards, flag)`` where ``rewards`` is a list of Python floats (one decoder
    output per sample) and ``flag`` is True when their mean is below 0.5, i.e. the decoded
    rewards can be used as-is; False means they should be flipped (``1 - r``).
    """
    rewards = []
    # Only the numeric outputs are needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(len(dataset)):
            sample = dataset[i]
            rewards.append(r_decoder(x=sample[0], a=sample[1], y=sample[2]).item())
    average_p1 = sum(rewards) / len(dataset)
    return rewards, average_p1 < 0.5


def train_vib_igl(num_epochs, num_batch, gamma, dataset, test_data, T_lr, d_lr, T_wd, d_wd, device,
                  iteration, test_each, sub_sample_num,
                  igl_policy, p_epoch, p_batch, p_lr, num_action, f_measure, r_decoder):
    """Train the reward decoder with the f-VI objective, then train the policy on decoded rewards.

    Odd epochs update the estimators T2 (conditional MINE estimator of I(Y;X,A|R)) and, when
    ``gamma <= 10``, T3 (MINE estimator of I(R;X,A)); even epochs update ``r_decoder`` on
    ``I(Y;X,A|R) - (1 / gamma) * I(R;X,A)``. Every ``2 * test_each`` epochs the decoder is
    scored with ``test_decoder`` and the result is recorded.

    Args:
        num_epochs: Number of decoder/estimator training epochs.
        num_batch: Number of samples drawn (without replacement) from ``dataset`` per epoch.
        gamma: Weight of the I(R;X,A) term (used as ``1 / gamma``); values above 10 drop the
            term and leave T3 untrained.
        dataset: Indexable training data; positions 0, 1, 2 of each sample are passed to the
            decoder as ``x``, ``a``, ``y``.
        test_data: Data handed to ``test_decoder`` and ``train_policy`` for evaluation.
        T_lr, T_wd: AdamW learning rate / weight decay for T2 and T3.
        d_lr, d_wd: AdamW learning rate / weight decay for ``r_decoder``.
        device: Device on which the estimators and helper tensors are placed.
        iteration: Run identifier, used only in log messages.
        test_each: Evaluation happens every ``2 * test_each`` epochs.
        sub_sample_num: Number of Bernoulli draws per sample used to build the r=1 / r=0
            index pools.
        igl_policy, p_epoch, p_batch, p_lr: Policy and its training settings, forwarded to
            ``train_policy``.
        num_action: Number of actions, forwarded to the estimators and ``train_policy``.
        f_measure: One of 'KL', 'PEARSON_KL', 'PEARSON'.
        r_decoder: Reward decoder module, called as ``r_decoder(x=..., a=..., y=...)``.

    Returns:
        Tuple ``(training_curve, training_curve_decoder)`` of NumPy arrays holding the policy
        values returned by ``train_policy`` and the decoder MSE at each evaluation point.

    Raises:
        ValueError: If ``f_measure`` is not a supported name.
    """
    if f_measure not in _F_MEASURE_TERMS:
        raise ValueError('Unsupported f_measure {!r}; expected one of {}'
                         .format(f_measure, sorted(_F_MEASURE_TERMS)))
    t2_second_term, t3_second_term = _F_MEASURE_TERMS[f_measure]
    eval_interval = 2 * test_each

    training_curve = np.zeros(p_epoch // eval_interval + 1)
    training_curve_decoder = np.zeros(num_epochs // eval_interval + 1)

    # Reward decoder
    # Evaluation
    mse = test_decoder(test_data=test_data, r_decoder=r_decoder, flag=True)
    training_curve_decoder[0] += mse
    print('f-VI-IGL, Iteration: {}, Epoch [0/{}], Decoder_MSE: {:.4f}'.format(iteration, num_epochs, mse))
    decoder_optimizer = optim.AdamW(r_decoder.parameters(), lr=d_lr, weight_decay=d_wd)
    r_decoder.train()

    # Conditional MINE Estimator of I(Y;X,A|R)
    T2 = mine.CMINE(num_action=num_action).to(device)
    T2_optimizer = optim.AdamW(T2.parameters(), lr=T_lr, weight_decay=T_wd)
    T2.train()
    # MINE Estimator of I(R;X,A)
    T3 = mine.MINE(num_action=num_action).to(device)
    T3_optimizer = optim.AdamW(T3.parameters(), lr=T_lr, weight_decay=T_wd)
    T3.train()

    # Constant reward inputs for the estimators, allocated once instead of per call.
    r_one = torch.ones((1, 1), device=device)
    r_zero = torch.zeros((1, 1), device=device)

    rewards = None
    flag = None
    last_decoded_epoch = None

    for epoch in range(1, num_epochs + 1):
        batch_idx = np.random.choice(len(dataset), num_batch, replace=False)
        batch = [dataset[j] for j in batch_idx]

        # Sample p(x,a|r) and p(y|r)
        r_list = []
        reward1_index_list = []
        reward0_index_list = []
        average_p1 = torch.zeros((1, 1), device=device)
        for i, sample in enumerate(batch):
            r1 = r_decoder(x=sample[0], a=sample[1], y=sample[2])
            average_p1 += r1
            p1 = r1.item()
            for _ in range(sub_sample_num):
                if np.random.random() <= p1:
                    reward1_index_list.append(i)
                else:
                    reward0_index_list.append(i)
            r_list.append(r1)
        average_p1 /= num_batch

        # Objectives of T3 (I(R;X,A))
        term3 = torch.zeros((1, 1), device=device)
        term4 = torch.zeros((1, 1), device=device)
        # Objectives of T2 (I(Y;X,A|R))
        term1 = torch.zeros((1, 1), device=device)
        term2 = torch.zeros((1, 1), device=device)

        for i, sample in enumerate(batch):
            x, a, y = sample[0], sample[1], sample[2]
            reward1_prob = r_list[i]
            t11 = T2(x=x, a=a, y=y, r=r_one)
            t10 = T2(x=x, a=a, y=y, r=r_zero)

            # Pair y with an (x, a) drawn from the r=1 / r=0 pools; fall back to the
            # sample itself when a pool is empty.
            if reward1_index_list:
                random_xa1_idx = reward1_index_list[np.random.randint(len(reward1_index_list))]
            else:
                random_xa1_idx = i
            if reward0_index_list:
                random_xa0_idx = reward0_index_list[np.random.randint(len(reward0_index_list))]
            else:
                random_xa0_idx = i
            xa1 = batch[random_xa1_idx]
            xa0 = batch[random_xa0_idx]
            t21 = T2(x=xa1[0], a=xa1[1], y=y, r=r_one)
            t20 = T2(x=xa0[0], a=xa0[1], y=y, r=r_zero)

            t31 = T3(x=x, a=a, r=r_one)
            t30 = T3(x=x, a=a, r=r_zero)

            # term 1 = E_{p(x,a,y)}\sum_{r=0,1}(p(r|x,a,y)* T_2(x,a,y,r))
            term1 += reward1_prob * t11 + (1 - reward1_prob) * t10
            # term 2 = E_{p(r)p(x,a|r)p(y)}[g(T_2(x,a,y,r))], g chosen by f_measure
            term2 += average_p1 * t2_second_term(t21) + (1 - average_p1) * t2_second_term(t20)
            # term4 = E_{p(x,a,r)}[T_3(x,a,r)]
            term4 += reward1_prob * t31 + (1 - reward1_prob) * t30
            # term3 = E_{p(x,a)p(r)}[g(T_3(x,a,r))], g chosen by f_measure
            term3 += average_p1 * t3_second_term(t31) + (1 - average_p1) * t3_second_term(t30)

        term1 /= num_batch
        term2 /= num_batch
        term3 /= num_batch
        term4 /= num_batch
        # f-VI-objective = I(Y;X,A|R) - (1 / gamma) * I(R;X,A)
        if gamma > 10:
            obj = term1 - term2
        else:
            obj = term1 - term2 - (1 / gamma) * (term4 - term3)

        if epoch % 2 == 1:
            # Estimator step: T2 (and T3 when its term is part of the objective).
            update_t3 = gamma <= 10
            if update_t3:
                T3_optimizer.zero_grad()
            T2_optimizer.zero_grad()
            obj.backward()
            # The optimizer minimizes, so negate T2's gradient to ascend on term1 - term2.
            for p in T2.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(-1)
            if update_t3:
                # gradient correction: undo the -(1 / gamma) factor up to sign, so T3
                # ascends on term4 - term3 at full scale.
                for p in T3.parameters():
                    if p.grad is not None:
                        p.grad.data.mul_(gamma)
                T3_optimizer.step()
            T2_optimizer.step()
        else:
            # Decoder step: minimize the objective with respect to r_decoder.
            decoder_optimizer.zero_grad()
            obj.backward()
            decoder_optimizer.step()

        if epoch % eval_interval == 0:
            # Reward decoder selection: flip the rewards if the decoded return of
            # V(\pi_b) is larger than 0.5.
            rewards, flag = _decode_rewards(dataset, r_decoder)
            last_decoded_epoch = epoch

            mse = test_decoder(test_data=test_data, r_decoder=r_decoder, flag=flag)
            training_curve_decoder[epoch // eval_interval] += mse
            print('VIB-IGL, Iteration: {}, Epoch [{}/{}], Decoder_MSE: {:.4f}'
                  .format(iteration, epoch, num_epochs, mse))

    # Prepare rewards list to train the policy. When the final epoch was not an
    # evaluation epoch (or no epoch ran), decode with the final decoder here so that
    # `rewards` is always defined and always reflects the trained decoder.
    if last_decoded_epoch != num_epochs:
        rewards, flag = _decode_rewards(dataset, r_decoder)
    if not flag:
        rewards = [1 - r for r in rewards]  # Flip the decoded reward

    # Train policy
    training_curve += train_policy(dataset=dataset,
                                   test_dataset=test_data,
                                   rewards_list=rewards,
                                   igl_policy=igl_policy,
                                   num_epochs=p_epoch,
                                   num_batch=p_batch,
                                   num_action=num_action,
                                   iteration=iteration,
                                   test_each=test_each,
                                   p_lr=p_lr,
                                   device=device)

    return training_curve, training_curve_decoder


# Train Policy
def train_policy(dataset, test_dataset, rewards_list, num_epochs, num_batch, num_action, iteration,
                 igl_policy, test_each, p_lr, device):
    """Train ``igl_policy`` on decoded rewards and record its value on ``test_dataset``.

    Each epoch draws ``num_batch`` samples without replacement from ``dataset`` and takes one
    AdamW step on the negated sum of ``num_action * reward * pi(a | x)``, where ``a`` is the
    index stored at position 5 of the sample and ``reward`` is ``rewards_list`` at the
    sample's dataset index. The policy is evaluated before training and then every
    ``2 * test_each`` epochs.

    Args:
        dataset: Indexable training data; position 0 is the policy input, position 5 the
            index of the logged action.
        test_dataset: Indexable evaluation data; position 0 is the policy input, position 4
            the index whose probability is averaged as the policy value.
        rewards_list: Decoded reward for every element of ``dataset`` (same indexing).
        num_epochs: Number of policy updates.
        num_batch: Batch size per update.
        num_action: Scale factor applied to every reward-weighted probability
            (presumably the importance weight of a uniform logging policy -- confirm).
        iteration: Run identifier, used only in log messages.
        igl_policy: Policy module; ``igl_policy(x)[0][k]`` is read as the probability of index k.
        test_each: Evaluation happens every ``2 * test_each`` epochs.
        p_lr: AdamW learning rate (weight decay is fixed at 0.5).
        device: Device for the loss accumulator.

    Returns:
        NumPy array of length ``num_epochs // (2 * test_each) + 1`` with the recorded values.
    """
    policy_optimizer = optim.AdamW(igl_policy.parameters(), lr=p_lr, weight_decay=0.5)
    igl_policy.train()

    eval_interval = 2 * test_each
    training_curve = np.zeros(num_epochs // eval_interval + 1)

    value = _evaluate_policy(igl_policy, test_dataset)
    training_curve[0] += value
    print('f-VI-IGL Policy, Iteration: {}, Epoch [0/{}], Value: {:.4f}'.format(iteration, num_epochs, value))

    for epoch in range(1, num_epochs + 1):
        cumulative_decoded_value = torch.zeros((1, 1), device=device)
        batch_idx = np.random.choice(len(dataset), num_batch, replace=False)

        # Off-policy Evaluation
        for idx in batch_idx:
            sample = dataset[idx]
            action_prob = igl_policy(sample[0])
            sample_action_index = sample[5]
            cumulative_decoded_value += num_action * rewards_list[idx] * action_prob[0][sample_action_index]

        # Maximize the decoded value by minimizing its negation.
        loss = -cumulative_decoded_value
        policy_optimizer.zero_grad()
        loss.backward()
        policy_optimizer.step()

        if epoch % eval_interval == 0:
            # Evaluate the policy by accuracy
            value = _evaluate_policy(igl_policy, test_dataset)
            training_curve[epoch // eval_interval] += value
            print('f-VI-IGL Policy, Iteration: {}, Epoch [{}/{}], Value: {:.4f}'
                  .format(iteration, epoch, num_epochs, value))

    return training_curve


def _evaluate_policy(igl_policy, test_dataset):
    """Return the mean probability the policy assigns to index ``sample[4]`` over ``test_dataset``.

    The result is a plain Python float, so callers can store it in NumPy arrays and format
    it without relying on implicit tensor conversion.
    """
    total = 0.0
    # No parameter update happens here, so autograd tracking is unnecessary.
    with torch.no_grad():
        for i in range(len(test_dataset)):
            sample = test_dataset[i]
            action_prob = igl_policy(sample[0])
            total += action_prob[0][sample[4]].item()
    return total / len(test_dataset)


# Test the reward decoder
def test_decoder(test_data, r_decoder, flag):
    loss = 0.
    if flag:
        for i in range(len(test_data)):
            loss += (r_decoder(x=test_data[i][0], a=test_data[i][1], y=test_data[i][2]).item() - test_data[i][3]) ** 2
    else:
        for i in range(len(test_data)):
            loss += (1 - r_decoder(x=test_data[i][0], a=test_data[i][1], y=test_data[i][2]).item() - test_data[i][3]) \
                    ** 2
    loss /= len(test_data)

    return loss
